// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h

#include <ATen/cuda/CUDAContext.h>
#include <cudaTypedefs.h>
#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/thread/mma.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <torch/all.h>

#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

#include "utils.h"

using namespace cute;

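// Both the SM89 and SM90 paths below implement a rowwise-scaled FP8 GEMM through an
// epilogue visitor tree (EVT):
//   out[i][j] = scales_a[i] * (scales_b[j] * (A @ B)[i][j]) (+ bias[j] when a bias is given)
// where A is an (m, k) row-major FP8 e4m3 matrix, B is a (k, n) column-major FP8 e4m3
// matrix, the accumulator is kept in float, scales_a holds one float per row of A,
// scales_b holds one float per column of B, and the optional bias is in the output dtype.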
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape,
          typename WarpShape, int Stages, bool WithBias, typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
          template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
          typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  using ElementA = ElementType;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementType;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = OutElementType;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  using ElementOutput = OutElementType;
  using LayoutOutput = cutlass::layout::RowMajor;
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  using ElementAccumulator = AccumElementType;
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm89;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
  // Number of epilogue stages in EVT
  static constexpr int EVTEpilogueStages = 1;

  using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<CtaShape, WarpShape, ElementC,
                                                                                     AlignmentC, EVTEpilogueStages>;

  // Definition of EVT
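  // The visitor tree computes, per output element (i, j):
  //   EpilogueBScale:                          acc[i][j] * scales_b[j]
  //   EpilogueAScale / EpilogueAScaleWithBias: (...) * scales_a[i] (+ bias[j])
  // before the result is written out through the aux-store node dTar.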
  using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;

  using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
  using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
                                                                        Stride<_0, _1, _0>>;
  using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;

  using ComputeAScale =
      cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue,
                                                     cutlass::FloatRoundStyle::round_to_nearest>;
  using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue,
                                                                        Stride<_1, _0, _0>>;
  using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;

  // With bias
  using biasSrc =
      cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
  using ComputeAScaleWithBias =
      cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiply_add, ElementC, ElementComputeEpilogue,
                                                     cutlass::FloatRoundStyle::round_to_nearest>;
  using EpilogueAScaleWithBias =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;

  using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride<int64_t, _1, _0>>;
  using EpilogueStore =
      typename cutlass::platform::conditional<WithBias,
                                              cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
                                              cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;

  using EpilogueOp = EpilogueStore;

  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB,
      cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator,
      ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp,
      ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                                               const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value(), "a bias tensor is required when WithBias is true");
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm,  // Mode
                                {m, n, k},                                // Problem size
                                1,                                        // Split-k factor
                                {},                                       // Epilogue args
                                ptr_a,                                    // a pointer
                                ptr_b,                                    // b pointer
                                nullptr,                                  // c pointer (unused)
                                nullptr,                                  // d pointer (unused)
                                m * k,                                    // batch stride a (unused)
                                n * k,                                    // batch stride b (unused)
                                m * n,                                    // batch stride c (unused)
                                m * n,                                    // batch stride d (unused)
                                lda,                                      // stride a
                                ldb,                                      // stride b
                                ldc,                                      // stride c (unused)
                                ldc);                                     // stride d (unused)
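  // The nested brace initializers below follow the Sm80EVT tree declared in
  // DeviceGemmFp8RowwiseSm89: the innermost level feeds the accumulator and the
  // row-broadcast scales_b into the first multiply, the next level applies the
  // column-broadcast scales_a (plus the row-broadcast bias when WithBias is set),
  // and the outermost level is the aux store that writes ptr_d with row stride n.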
  if constexpr (WithBias) {
    args.epilogue = {{
                         {
                             {},  // Accumulator
                             {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                             {}  // Multiplies
                         },
                         {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
                         {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
                         {}  // Multiplies
                     },
                     {ptr_d, {n, _1{}, _0{}}}};
  } else {
    args.epilogue = {{
                         {
                             {},  // Accumulator
                             {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
                             {}  // Multiplies
                         },
                         {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
                         {}  // Multiplies
                     },
                     {ptr_d, {n, _1{}, _0{}}}};
  }

  return args;
}

template <typename Gemm, bool WithBias>
void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                               const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
              "sm89 fp8 gemm cannot implement this problem: ", cutlass::cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "sm89 fp8 gemm failed to run: ", cutlass::cutlassGetStatusString(status));
}

template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                            const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                            const c10::optional<torch::Tensor>& bias) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;
  if (bias) {
    using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
                                                   Stages, true>::Gemm;
    return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape,
                                                   Stages, false>::Gemm;
    return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}

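// Heuristic dispatch: pick a CTA shape, warp shape, and pipeline stage count
// based on the GEMM problem size (m, n).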
template <typename OutType>
void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                             const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  uint32_t const n = out.size(1);

  if (m == 1) {
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 16) {
    // M in (1, 16]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    // M in (16, 64]
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    // M in (64, 128]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
                                    cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>,
                                    cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>,
                                    cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 256) {
    // M in (128, 256]
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
    } else if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>,
                                    cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
    }
  } else if (m <= 512) {
    // M in (256, 512]
    if (n <= 16384) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    // M in (512, inf)
    if (n <= 8192) {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
    } else {
      return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>,
                                    cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
    }
  }
}
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
          typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
          typename TileSchedulerType = void, bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  // A matrix configuration
  using ElementA = ElementType;               // Element type for A matrix operand
  using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A
                                                    // matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = ElementType;                  // Element type for B matrix operand
  using LayoutB = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B
                                                    // matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = void;                      // Element type for C matrix operands
  using LayoutC = cutlass::layout::RowMajor;  // Layout type for C matrix operands
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<OutElementType>::value;  // Memory access granularity/alignment of C matrices in
                                                          // units of elements (up to 16 bytes)

  // Output matrix configuration
  using ElementOutput = OutElementType;            // Element type for output matrix operands
  using LayoutOutput = cutlass::layout::RowMajor;  // Layout type for output matrix operands
  static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

  // // Auxiliary matrix configuration and other fusion types
  // using ElementBias = float;

  // Multiply-accumulate blocking/pipelining details
  using ElementAccumulator = AccumElementType;  // Element type for internal accumulation
  using ElementCompute = float;                 // Element type for compute
  using ElementComputeEpilogue = float;
  using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
  using TileShape = CTAShape;                            // Threadblock-level tile size

  static constexpr bool PONG = false;
  static constexpr bool FAST_ACCUM = true;
  static constexpr bool USE_BIAS = false;

  using StageCountType = cutlass::gemm::collective::StageCountAuto;      // Stage count maximized
                                                                         // based on the tile size
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;  // Kernel to launch based on the default
                                                                         // setting in the Collective Builder
  // Implement rowwise scaling epilogue.
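  // XScale holds scales_a: one value per output row (per row of A), broadcast across columns.
  // WScale holds scales_b: one value per output column (per column of B), broadcast across rows.
  // Bias is likewise a per-column vector in the output dtype, broadcast across rows.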
  using XScale =
      cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
                                                  cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  using WScale =
      cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue,
                                                  cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
                                                           cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies,
                                                          ElementComputeEpilogue,  // First stage output type.
                                                          ElementComputeEpilogue,  // First stage input types.
                                                          cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput,
                                                          ElementComputeEpilogue,  // Second stage input types.
                                                          cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias
  using ComputeWithBias =
      cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementComputeEpilogue,
                                             cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC,
      AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized,
      EpilogueEVT>::CollectiveOp;

  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

  using SlowAccum = DefaultSchedule;
  using FastAccum = FastPongSchedule;  // Defaults to the pingpong FP8 fast-accum schedule

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
                                                          CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                                               const c10::optional<torch::Tensor>& bias) {
  using ElementT = typename Gemm::ElementA;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
  ElementOutput const* ptr_bias = nullptr;
  if constexpr (WithBias) {
    TORCH_CHECK(bias.has_value(), "a bias tensor is required when WithBias is true");
    ptr_bias = reinterpret_cast<ElementOutput const*>(bias.value().data_ptr());
  }
  ElementOutput* ptr_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
  ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
  ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
  typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
                                   {m, n, k, 1},
                                   {ptr_a, stride_a, ptr_b, stride_b},
                                   {{},  // epilogue.thread
                                    nullptr,
                                    stride_c,
                                    ptr_d,
                                    stride_d}};
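  // The nested initializers below follow the Sm90EVT tree declared in
  // DeviceGemmFp8RowwiseSm90: child node arguments come first (XScale's scales_a,
  // then the inner EVT holding WScale's scales_b and the accumulator, then the
  // optional Bias), followed by the compute node's own (empty) arguments.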
  if constexpr (WithBias) {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {ptr_bias},
        {},  // Multiplies
    };
  } else {
    args.epilogue.thread = {
        {ptr_scales_a},
        {
            {ptr_scales_b},
            {},  // Accumulator
            {}   // Multiplies
        },
        {},  // Multiplies
    };
  }

  return args;
}

template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                               const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                               const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
  Gemm gemm_op;

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
              "sm90 fp8 gemm cannot implement this problem: ", cutlass::cutlassGetStatusString(can_implement));

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);

  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "sm90 fp8 gemm failed to run: ", cutlass::cutlassGetStatusString(status));
}

template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType,
          typename TileSchedulerType>
void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                            const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                            const c10::optional<torch::Tensor>& bias, bool fast_accum = true,
                            bool use_persistent = false) {
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;
  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;

  if (bias) {
    using Gemm =
        typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
                                          MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm =
        typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
                                          MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm;
    return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}

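// Heuristic dispatch: pick a CTA shape, cluster shape, and mainloop/tile
// schedulers based on the number of rows m.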
template <typename OutType>
void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                             const c10::optional<torch::Tensor>& bias) {
  uint32_t const m = a.size(0);
  using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
  using BasicTileScheduler = void;
  if (m <= 1) {
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler,
                                  BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
  if (m <= 64) {
    // m in (1, 64]
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 256) {
    // m in (64, 256]
    return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else if (m <= 1024) {
    // m in (256, 1024]
    return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  } else {
    // m in (1024, inf)
    return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler,
                                  PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
  }
}
#endif

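// Entry point: computes out = (mat_a @ mat_b) scaled rowwise by scales_a and
// columnwise by scales_b, plus an optional bias. mat_a is (m, k) row-major FP8
// e4m3, mat_b is (k, n) column-major FP8 e4m3, scales_a has m float32 elements,
// scales_b has n float32 elements, and bias (if given) has n elements in out_dtype.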
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                            const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                            const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
              "mat_a row size must be a multiple of 16 bytes for memory alignment");
  TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
              "mat_b column size must be a multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  TORCH_CHECK(scales_a.numel() == mat_a.size(0), "scales_a must have one scale per row of mat_a");
  TORCH_CHECK(scales_b.numel() == mat_b.size(1), "scales_b must have one scale per column of mat_b");
  TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous");
  TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  if (bias) {
    TORCH_CHECK(bias->numel() == mat_b.size(1), "bias must have one element per column of mat_b");
    TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous");
    TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype");
  }

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0,
              "out row size must be a multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm90_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
  if (sm_version == 89) {
    if (out_dtype == torch::kBFloat16) {
      sm89_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm89_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(false, "fp8_scaled_mm is not implemented for the current compute capability: ",
                              sm_version);
}
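
// Illustrative usage (a minimal sketch; m, k, and n are placeholder sizes, and the
// exact way tensors are built and the op is exposed depends on the caller):
//
//   auto fp8_opts = torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA);
//   auto f32_opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   torch::Tensor a = torch::empty({m, k}, fp8_opts);       // row-major activations
//   torch::Tensor b = torch::empty({n, k}, fp8_opts).t();   // (k, n) column-major weights
//   torch::Tensor scales_a = torch::ones({m}, f32_opts);    // one scale per row of a
//   torch::Tensor scales_b = torch::ones({n}, f32_opts);    // one scale per column of b
//   torch::Tensor out = fp8_scaled_mm(a, b, scales_a, scales_b, torch::kBFloat16, c10::nullopt);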
